[PyTorch] Integrate the cuBLAS single GEMM MXFP8 NN, NT support for sm120#3050
Draft
KshitijLakhani wants to merge 8 commits into
Draft
[PyTorch] Integrate the cuBLAS single GEMM MXFP8 NN, NT support for sm120#3050KshitijLakhani wants to merge 8 commits into
KshitijLakhani wants to merge 8 commits into
Conversation
Adds NVTE_ENABLE_MXFP8_SM120 environment variable to unblock MXFP8 testing on sm120 (compute capability 12.0) devices. Default behavior unchanged; MXFP8 remains gated off on sm120 without explicit opt-in since not all GEMM layouts are currently supported. Also adds tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py: a focused layout x shape x dtype matrix exercising MXFP8 single GEMM via the underlying general_gemm call directly. The TN layout is exercised across small/medium/transformer-sized shapes and BF16/FP32 outputs. NN and NT layouts on sm120 are marked strict-xfail; the suite will fail-on-XPASS once full layout support is added so the markers can be removed.
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…13.6+ Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Remove tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py. The TN/NN/NT MXFP8 GEMM code paths it was added to localize are already exercised end-to-end on sm_120 (with cuBLASLt >= 13.6.0.2) by the existing te.Linear / te.LayerNormLinear / te.GroupedLinear / te.TransformerLayer numerics tests in tests/pytorch/test_numerics.py via the MXFP8BlockScaling entry in fp8_recipes (each Linear forward + backward dispatches the three cuBLAS GEMMs as fwd=TN, dgrad=NN, wgrad=NT). The runtime _compute_mxfp8_support gate added in the earlier commits on this branch already module-skips MXFP8 below cuBLASLt 13.6.0.2 on sm_120, so the per-layout strict-xfail layer in this file is redundant. Out-of-tree triage material (Testing/repro_mxfp8_layouts.cu and the Testing/repro_mxfp8_layouts.py driver) remains available if a future cuBLAS regression needs layout-localized signal again.
cuBLASLt 13.6.0.2 supports single-GEMM MXFP8 on sm_120 / sm_121 but not the grouped variant. Route general_grouped_gemm and general_grouped_gemm_for_grouped_tensor through check_mxfp8_grouped_gemm_support() and raise NotImplementedError when unsupported, instead of failing opaquely inside cuBLAS. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Introduce _compute_mxfp8_grouped_gemm_support / check_mxfp8_grouped_gemm_support and a public is_mxfp8_grouped_gemm_available helper so callers (te.GroupedLinear, general_grouped_gemm[_for_grouped_tensor], and grouped-GEMM tests) can gate on grouped MXFP8 separately from single-GEMM MXFP8. On sm_120 / sm_121, cuBLASLt 13.6.0.2 enables single MXFP8 GEMM (TN/NN/NT) but not the grouped variant; the new gate returns False there with a descriptive reason. Also widen the single-GEMM gate to sm_121 alongside sm_120. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Probe is_mxfp8_grouped_gemm_available in test_fusible_ops, test_numerics, and test_sanity, and pytest.skip MXFP8 grouped_linear / padding_grouped_linear / grouped_gemm cases (plus a maybe_skip_quantization_for_grouped_gemm helper in test_fusible_ops) with the gate's reason. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
ffa0eab to
0370659
Compare
for more information, see https://pre-commit.ci
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Enable MXFP8 support for NT, NN single GEMMs via cuBLAS for sm120.
fwd=TN, dgrad=NN, wgrad=NT.
Fixes #2668
This PR is complimentary to: #2833
It would be best to merge PR 3050 along with (ideally, right after) PR 2833.
The reason is that PR 2833 goes hand in hand with a CI PR that enables sm120 in the CI as well. Hence, it would be best to merge 3050 when sm120 in CI is available and not prior.
TODO: Code and git history clean up
Type of change
Changes
_compute_mxfp8_support()to enable usage of cuBLAS MXFP8 NN and NT single GEMMs in TE_compute_mxfp8_grouped_gemm_support()to differentiate support between single GEMMs (_compute_mxfp8_support()) and grouped GEMM (_compute_mxfp8_grouped_gemm_support()) - the latter is not supported in cuBLAS yetChecklist: